{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "laa9tRjJ59bl"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "T4ZHtBpK6Dom"
      },
      "outputs": [],
      "source": [
        "#@title Copyright 2020 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hk5u_9KN1m-t"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/hub/tutorials/yamnet\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/hub/tutorials/yamnet.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/hub/tutorials/yamnet.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/hub/tutorials/yamnet.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://tfhub.dev/google/yamnet/1\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x2ep-q7k_5R-"
      },
      "source": [
        "# Sound classification with YAMNet\n",
        "\n",
        "YAMNet is a deep net that predicts 521 audio event [classes](https://github.com/tensorflow/models/blob/master/research/audioset/yamnet/yamnet_class_map.csv) from the [AudioSet-YouTube corpus](http://g.co/audioset) it was trained on. It employs the\n",
        "[Mobilenet_v1](https://arxiv.org/pdf/1704.04861.pdf) depthwise-separable\n",
        "convolution architecture."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bteu7pfkpt_f"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "import numpy as np\n",
        "import csv\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from IPython.display import Audio\n",
        "import scipy.signal  # Needed by scipy.signal.resample in ensure_sample_rate.\n",
        "from scipy.io import wavfile"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YSVs3zRrrYmY"
      },
      "source": [
        "Load the model from TensorFlow Hub.\n",
        "\n",
        "Note: To read the model's documentation, follow its [URL](https://tfhub.dev/google/yamnet/1)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VX8Vzs6EpwMo"
      },
      "outputs": [],
      "source": [
        "# Load the model.\n",
        "model = hub.load('https://tfhub.dev/google/yamnet/1')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lxWx6tOdtdBP"
      },
      "source": [
        "The labels file is stored in the model's assets, at the path returned by `model.class_map_path()`.\n",
        "You will load it into the `class_names` variable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EHSToAW--o4U"
      },
      "outputs": [],
      "source": [
        "# Read the class names from the class map CSV stored in the model's assets.\n",
        "def class_names_from_csv(class_map_csv_text):\n",
        "  \"\"\"Returns list of class names corresponding to score vector.\"\"\"\n",
        "  class_names = []\n",
        "  with tf.io.gfile.GFile(class_map_csv_text) as csvfile:\n",
        "    reader = csv.DictReader(csvfile)\n",
        "    for row in reader:\n",
        "      class_names.append(row['display_name'])\n",
        "\n",
        "  return class_names\n",
        "\n",
        "class_map_path = model.class_map_path().numpy()\n",
        "class_names = class_names_from_csv(class_map_path)"
      ]
    },
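    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the parsing step concrete, here is a minimal sketch that runs `csv.DictReader` over an in-memory CSV instead of the model's asset file. It assumes the class map uses the columns `index`, `mid`, and `display_name`. The rows, the placeholder `mid` values, and the names `sample_csv` and `sample_names` are illustrative and not part of the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import csv\n",
        "import io\n",
        "\n",
        "# Illustrative rows only; the mid values are placeholders, not real AudioSet ids.\n",
        "sample_csv = io.StringIO(\n",
        "    'index,mid,display_name\\n'\n",
        "    '0,/m/placeholder0,Speech\\n'\n",
        "    '1,/m/placeholder1,\"Child speech, kid speaking\"\\n'\n",
        "    '2,/m/placeholder2,Cat\\n')\n",
        "\n",
        "# DictReader takes its keys from the header row and handles quoted commas.\n",
        "sample_names = [row['display_name'] for row in csv.DictReader(sample_csv)]\n",
        "print(sample_names)"
      ]
    },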
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mSFjRwkZ59lU"
      },
      "source": [
        "Add a function that checks the sample rate of the loaded audio and resamples it to 16 kHz if needed. Audio at any other sample rate would degrade the model's results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LizGwWjc5w6A"
      },
      "outputs": [],
      "source": [
        "def ensure_sample_rate(original_sample_rate, waveform,\n",
        "                       desired_sample_rate=16000):\n",
        "  \"\"\"Resample waveform if required.\"\"\"\n",
        "  if original_sample_rate != desired_sample_rate:\n",
        "    desired_length = int(round(float(len(waveform)) /\n",
        "                               original_sample_rate * desired_sample_rate))\n",
        "    waveform = scipy.signal.resample(waveform, desired_length)\n",
        "  return desired_sample_rate, waveform"
      ]
    },
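    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the length computation in `ensure_sample_rate`, the sketch below resamples a synthetic tone. The 440 Hz tone, the 44.1 kHz source rate, and all variable names are assumptions made for this example. `scipy.signal.resample` uses a Fourier method, so other resamplers may produce slightly different samples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import scipy.signal\n",
        "\n",
        "# Hypothetical input: one second of a 440 Hz tone sampled at 44.1 kHz.\n",
        "source_rate = 44100\n",
        "target_rate = 16000\n",
        "t = np.arange(source_rate) / source_rate\n",
        "tone = np.sin(2 * np.pi * 440 * t)\n",
        "\n",
        "# Scale the sample count by the rate ratio so the duration stays the same.\n",
        "target_length = int(round(len(tone) / source_rate * target_rate))\n",
        "resampled = scipy.signal.resample(tone, target_length)\n",
        "\n",
        "print(len(tone), len(resampled))"
      ]
    },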
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AZEgCobA9bWl"
      },
      "source": [
        "## Downloading and preparing the sound file\n",
        "\n",
        "Here you will download a WAV file and listen to it.\n",
        "If you already have a file available, upload it to Colab and use it instead.\n",
        "\n",
        "Note: The expected audio file should be a mono WAV file at a 16 kHz sample rate."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WzZHvyTtsJrc"
      },
      "outputs": [],
      "source": [
        "!curl -O https://storage.googleapis.com/audioset/speech_whistling2.wav"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D8LKmqvGzZzr"
      },
      "outputs": [],
      "source": [
        "!curl -O https://storage.googleapis.com/audioset/miaow_16k.wav"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wo9KJb-5zuz1"
      },
      "outputs": [],
      "source": [
        "# wav_file_name = 'speech_whistling2.wav'\n",
        "wav_file_name = 'miaow_16k.wav'\n",
        "# Read the sample rate and the raw samples, then resample to 16 kHz if needed.\n",
        "sample_rate, wav_data = wavfile.read(wav_file_name)\n",
        "sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)\n",
        "\n",
        "# Show some basic information about the audio.\n",
        "duration = len(wav_data)/sample_rate\n",
        "print(f'Sample rate: {sample_rate} Hz')\n",
        "print(f'Total duration: {duration:.2f}s')\n",
        "print(f'Size of the input: {len(wav_data)}')\n",
        "\n",
        "# Listen to the wav file.\n",
        "Audio(wav_data, rate=sample_rate)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P9I290COsMBm"
      },
      "source": [
        "The `wav_data` needs to be normalized to values in `[-1.0, 1.0]` (as stated in the model's [documentation](https://tfhub.dev/google/yamnet/1))."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bKr78aCBsQo3"
      },
      "outputs": [],
      "source": [
        "waveform = wav_data / tf.int16.max"
      ]
    },
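    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If your own file is stereo, something like the following sketch may help. It assumes 16-bit PCM data shaped `(num_samples, num_channels)`, which is how `wavfile.read` returns multi-channel audio. The sample values and variable names are made up for illustration. The sketch clips after scaling because the most negative int16 value, divided by the int16 maximum, falls just below -1.0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical stereo int16 audio: four samples, two channels.\n",
        "stereo_int16 = np.array(\n",
        "    [[0, 0], [16384, -16384], [32767, 32767], [-32768, -32768]],\n",
        "    dtype=np.int16)\n",
        "\n",
        "# Average the channels in float32 so the sum cannot overflow int16.\n",
        "mono = stereo_int16.astype(np.float32).mean(axis=1)\n",
        "\n",
        "# Scale by the int16 maximum, then clip to keep every value in [-1.0, 1.0].\n",
        "mono_waveform = np.clip(mono / np.iinfo(np.int16).max, -1.0, 1.0)\n",
        "\n",
        "print(mono_waveform.shape, mono_waveform.min(), mono_waveform.max())"
      ]
    },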
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e_Xwd4GPuMsB"
      },
      "source": [
        "## Executing the Model\n",
        "\n",
        "Now the easy part: with the data prepared, call the model to get the scores, the embeddings, and the spectrogram.\n",
        "\n",
        "The scores are the main result you will use.\n",
        "You will use the spectrogram for some visualizations later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BJGP6r-At_Jc"
      },
      "outputs": [],
      "source": [
        "# Run the model, check the output.\n",
        "scores, embeddings, spectrogram = model(waveform)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vmo7griQprDk"
      },
      "outputs": [],
      "source": [
        "scores_np = scores.numpy()\n",
        "spectrogram_np = spectrogram.numpy()\n",
        "# Find the name of the class with the top score when mean-aggregated across frames.\n",
        "inferred_class = class_names[scores_np.mean(axis=0).argmax()]\n",
        "print(f'The main sound is: {inferred_class}')"
      ]
    },
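    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The aggregation in the cell above can be sketched on a toy matrix. The class names and scores here are invented for illustration (real YAMNet scores have 521 columns), and the `toy_` names are hypothetical. Averaging over axis 0 gives one clip-level score per class, while an argmax over axis 1 gives the top class for each frame."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical scores for 4 frames (rows) and 3 classes (columns).\n",
        "toy_class_names = ['Speech', 'Cat', 'Whistling']\n",
        "toy_scores = np.array([[0.1, 0.7, 0.2],\n",
        "                       [0.2, 0.6, 0.1],\n",
        "                       [0.1, 0.8, 0.3],\n",
        "                       [0.6, 0.3, 0.1]])\n",
        "\n",
        "# Clip-level prediction: average over frames, then take the best class.\n",
        "toy_mean = toy_scores.mean(axis=0)\n",
        "clip_class = toy_class_names[toy_mean.argmax()]\n",
        "\n",
        "# Frame-level predictions: best class for each row.\n",
        "frame_classes = [toy_class_names[i] for i in toy_scores.argmax(axis=1)]\n",
        "\n",
        "print(clip_class, frame_classes)"
      ]
    },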
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uj2xLf-P_ndS"
      },
      "source": [
        "## Visualization\n",
        "\n",
        "YAMNet also returns some additional information that you can use for visualization.\n",
        "Take a look at the waveform, the spectrogram, and the top inferred classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_QSTkmv7wr2M"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10, 6))\n",
        "\n",
        "# Plot the waveform.\n",
        "plt.subplot(3, 1, 1)\n",
        "plt.plot(waveform)\n",
        "plt.xlim([0, len(waveform)])\n",
        "\n",
        "# Plot the log-mel spectrogram (returned by the model).\n",
        "plt.subplot(3, 1, 2)\n",
        "plt.imshow(spectrogram_np.T, aspect='auto', interpolation='nearest', origin='lower')\n",
        "\n",
        "# Plot and label the model output scores for the top-scoring classes.\n",
        "mean_scores = np.mean(scores, axis=0)\n",
        "top_n = 10\n",
        "top_class_indices = np.argsort(mean_scores)[::-1][:top_n]\n",
        "plt.subplot(3, 1, 3)\n",
        "plt.imshow(scores_np[:, top_class_indices].T, aspect='auto', interpolation='nearest', cmap='gray_r')\n",
        "\n",
        "# patch_padding = (PATCH_WINDOW_SECONDS / 2) / PATCH_HOP_SECONDS\n",
        "# values from the model documentation\n",
        "patch_padding = (0.025 / 2) / 0.01\n",
        "plt.xlim([-patch_padding-0.5, scores.shape[0] + patch_padding-0.5])\n",
        "# Label the top_N classes.\n",
        "yticks = range(0, top_n, 1)\n",
        "plt.yticks(yticks, [class_names[top_class_indices[x]] for x in yticks])\n",
        "_ = plt.ylim(-0.5 + np.array([top_n, 0]))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "yamnet.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
